import torch
import matplotlib.pyplot as plt


def position_encoding(seq_len, embedding_dims, plot=True):
    """Build the sinusoidal positional-encoding matrix from "Attention Is All You Need".

    P[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    P[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        seq_len: number of positions (rows of P).
        embedding_dims: model dimension d_model (columns of P); may be odd.
        plot: when True (default, preserving the original script behavior),
            display the matrix with matplotlib's imshow.

    Returns:
        torch.Tensor of shape (seq_len, embedding_dims) holding the encoding.
    """
    # Positional-encoding matrix, e.g. (1000, 500).
    P = torch.zeros((seq_len, embedding_dims))
    d_model = embedding_dims

    # angles[pos, i] = pos / 10000^(2i/d_model); shape (seq_len, ceil(d_model/2)).
    angles = torch.arange(seq_len).unsqueeze(-1) / torch.pow(
        10000, torch.arange(0, embedding_dims, 2) / d_model
    )
    P[:, 0::2] = torch.sin(angles)
    # For odd embedding_dims the cosine slice has one fewer column than
    # `angles`, so trim to the slice width instead of crashing.
    P[:, 1::2] = torch.cos(angles[:, : P[:, 1::2].shape[1]])

    if plot:
        # Visualize the encoding as an image (original behavior).
        plt.imshow(P.detach().numpy())
        plt.show()

    return P


if __name__ == "__main__":
    # Demo: render the 1000x500 positional-encoding matrix. Guarded so that
    # importing this module does not pop a matplotlib window as a side effect.
    position_encoding(1000, 500)
